import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd


# Global matplotlib styling shared by every figure drawn in this script.
params = {
    # Axis labels and the default font use the larger text size.
    **dict.fromkeys(("axes.labelsize", "font.size"), 22),
    # Legend entries and tick labels are slightly smaller.
    **dict.fromkeys(("legend.fontsize", "xtick.labelsize", "ytick.labelsize"), 20),
    # Grid lines: black, dotted, one point wide.
    "grid.color": "k",
    "grid.linestyle": ":",
    "grid.linewidth": 1,
}
mpl.rcParams.update(params)

# Download the ISO 3166 country table and build an alpha-2 code -> name lookup.
# na_filter=False turns off missing-value detection, so every cell (for example
# a literal "NA" code) is kept as the string that appears in the file.
country_code = pd.read_csv(
    "https://raw.githubusercontent.com/lukes/ISO-3166-Countries-with-Regional-Codes/refs/heads/master/all/all.csv",
    na_filter=False,
)
dict_country_name = country_code.set_index("alpha-2")["name"].to_dict()
# Override three entries with the display names used in the figure labels.
dict_country_name.update(
    {
        "RU": "Russia",
        "US": "United States",
        "GB": "United Kingdom",
    }
)

# Migration-flow estimates. The code below reads the columns migration_month,
# country_from, country_to and num_migrants. na_filter=False keeps every cell
# as written in the file instead of converting NA-like strings to NaN.
estimate = pd.read_csv("international_migration_flow.csv", na_filter=False)

# Total number of migrants leaving each origin country, summed over all rows
# whose migration_month string contains "2022".
source = (
    estimate[estimate["migration_month"].str.contains("2022")]
    .groupby("country_from", as_index=False)
    .agg({"num_migrants": "sum"})
)
# Keep only origins with a known ISO code. The explicit copy makes `source` an
# independent frame before a column is added to it.
source = source[source["country_from"].isin(dict_country_name.keys())].copy()
# Vectorised code -> name lookup. Unlike a row-wise DataFrame.apply(axis=1),
# this also works when the filtered frame is empty.
source["origin_full_name"] = source["country_from"].map(dict_country_name)

# Five largest origins, with counts expressed in millions and
# presentation-ready column names.
source_top5 = (
    source.sort_values(by="num_migrants", ascending=False)
    .head(5)
    .reset_index(drop=True)
    .assign(num_migrants=lambda df: df["num_migrants"] / 1e6)
    .rename(
        columns={
            "country_from": "Code",
            "num_migrants": "# Migrants (Millions)",
            "origin_full_name": "Country",
        }
    )
)

# Total number of migrants arriving in each destination country, summed over
# all rows whose migration_month string contains "2022".
destination = (
    estimate[estimate["migration_month"].str.contains("2022")]
    .groupby("country_to", as_index=False)
    .agg({"num_migrants": "sum"})
)
# Keep only destinations with a known ISO code. The explicit copy makes
# `destination` an independent frame before a column is added to it.
destination = destination[
    destination["country_to"].isin(dict_country_name.keys())
].copy()
# Vectorised code -> name lookup. Unlike a row-wise DataFrame.apply(axis=1),
# this also works when the filtered frame is empty.
destination["destination_full_name"] = destination["country_to"].map(
    dict_country_name
)

# Five largest destinations, with counts expressed in millions and
# presentation-ready column names.
destination_top5 = (
    destination.sort_values(by="num_migrants", ascending=False)
    .head(5)
    .reset_index(drop=True)
    .assign(num_migrants=lambda df: df["num_migrants"] / 1e6)
    .rename(
        columns={
            "country_to": "Code",
            "num_migrants": "# Migrants (Millions)",
            "destination_full_name": "Country",
        }
    )
)

# Total number of migrants for each (origin, destination) country pair, summed
# over all rows whose migration_month string contains "2022".
pair = (
    estimate[estimate["migration_month"].str.contains("2022")]
    .groupby(["country_from", "country_to"], as_index=False)
    .agg({"num_migrants": "sum"})
)
# Keep only pairs where both ends have a known ISO code. The explicit copy
# makes `pair` an independent frame before columns are added to it.
pair = pair[
    pair["country_from"].isin(dict_country_name.keys())
    & pair["country_to"].isin(dict_country_name.keys())
].copy()
# Vectorised code -> name lookups. Unlike a row-wise DataFrame.apply(axis=1),
# these also work when the filtered frame is empty.
pair["from"] = pair["country_from"].map(dict_country_name)
pair["to"] = pair["country_to"].map(dict_country_name)

# Five largest corridors, with counts expressed in millions and
# presentation-ready column names.
pair_top5 = (
    pair.sort_values(by="num_migrants", ascending=False)
    .head(5)
    .reset_index(drop=True)
    .assign(num_migrants=lambda df: df["num_migrants"] / 1e6)
    .rename(
        columns={
            "from": "Origin",
            "country_from": "Origin code",
            "to": "Destination",
            "country_to": "Destination code",
            "num_migrants": "# Migrants (Millions)",
        }
    )
)

# Net migration per country: total inflow minus total outflow. The inner join
# keeps only countries that appear both as an origin and as a destination.
# The shared num_migrants column gets the "_from" / "_to" suffixes.
net = source.merge(
    destination,
    left_on="country_from",
    right_on="country_to",
    how="inner",
    suffixes=("_from", "_to"),
)
net["net"] = net["num_migrants_to"] - net["num_migrants_from"]
net["iso2"] = net["country_from"]

# Sort once and take both ends of the ranking from the same ordered frame.
net_ranked = net.sort_values(by="net", ascending=False)
net_top5 = net_ranked.head(5)
net_bottom5 = net_ranked.tail(5)
# The re-sort after concatenation matters when fewer than ten countries exist:
# head and tail then overlap, and sorting keeps the combined rows in
# descending order.
net_10 = (
    pd.concat([net_top5, net_bottom5])
    .sort_values(by="net", ascending=False)
    .reset_index(drop=True)
)
# Vectorised code -> name lookup. Unlike a row-wise DataFrame.apply(axis=1),
# this also works on an empty frame.
net_10["country"] = net_10["iso2"].map(dict_country_name)
# Express net migration in millions for plotting.
net_10["net"] = net_10["net"] / 1e6

# Figure 2b: horizontal bar chart of the five largest and five smallest net
# migration values held in net_10 (sorted descending, index 0..9).
# NOTE: the layout below assumes net_10 has exactly ten rows.

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("figures", exist_ok=True)

fig, ax = plt.subplots(figsize=(6, 6))
for spine_side in ("top", "right", "left"):
    ax.spines[spine_side].set_visible(False)
ax.set_yticks([])

# Bars are drawn bottom-up, so the descending series is reversed: rows 0-4 of
# net_10 occupy y = 9..5 (blue) and rows 5-9 occupy y = 4..0 (orange).
net_bottom_up = net_10["net"][::-1]
ax.barh(range(5, len(net_10)), net_bottom_up[-5:], color="#456a83", height=0.7)
ax.barh(range(5), net_bottom_up[:5], color="#dd9769", height=0.7)

for index, row in net_10.iterrows():
    y_pos = 9 - index
    in_top_five = index < 5
    # Top-five rows: name right-aligned against x = 0, value right-aligned at
    # the bar end. Remaining rows: both texts left-aligned instead.
    alignment = "right" if in_top_five else "left"
    name_label = row["country"] + " " if in_top_five else " " + row["country"]
    value_label = f"{row['net']:.2f} " if in_top_five else f" {row['net']:.2f}"
    ax.text(
        0,
        y_pos,
        name_label,
        horizontalalignment=alignment,
        verticalalignment="center",
        fontsize=15,
    )
    ax.text(
        row["net"],
        y_pos,
        value_label,
        horizontalalignment=alignment,
        verticalalignment="center",
        fontsize=13,
        color="white",
    )

ax.set_xlabel("Net Migration (Millions)")
fig.savefig(
    "figures/fig2b.pdf",
    facecolor="white",
    transparent=False,
    bbox_inches="tight",
    dpi=300,
)


# Figure 2c: horizontal bar chart of the top destination countries
# (destination_top5, sorted descending, index 0..n-1).

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("figures", exist_ok=True)

fig, ax = plt.subplots(figsize=(6, 4))
ax.axis("off")
# Wide left margin leaves room for the country names drawn left of x = 0.
fig.subplots_adjust(left=0.55, bottom=0.001, right=0.999, top=0.86)

inflow_millions = destination_top5["# Migrants (Millions)"]
# Bars are drawn bottom-up, so the descending series is reversed to put the
# largest value on top.
ax.barh(
    range(len(destination_top5)),
    inflow_millions[::-1],
    color="#456a83",
    height=0.7,
)
# Derive label rows from the actual row count so labels stay aligned with
# their bars even when fewer than five countries are available.
top_row = len(destination_top5) - 1
for rank, (country_name, millions) in enumerate(
    zip(destination_top5["Country"], inflow_millions)
):
    y_pos = top_row - rank
    ax.text(
        0,
        y_pos,
        country_name + " ",
        horizontalalignment="right",
        verticalalignment="center",
        fontsize=15 * 1.4,
    )
    ax.text(
        millions,
        y_pos,
        f"{millions:.2f} ",
        horizontalalignment="right",
        verticalalignment="center",
        fontsize=13 * 1.4,
        color="white",
    )
ax.set_title("Top 5 inflow countries", x=0.0, y=1.02)
fig.savefig(
    "figures/fig2c.pdf",
    facecolor="white",
    transparent=False,
    dpi=300,
)


# Figure 2d: horizontal bar chart of the top origin countries
# (source_top5, sorted descending, index 0..n-1).

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("figures", exist_ok=True)

fig, ax = plt.subplots(figsize=(6, 4))
ax.axis("off")
# Left margin leaves room for the country names drawn left of x = 0.
fig.subplots_adjust(left=0.36, bottom=0.001, right=0.999, top=0.86)

outflow_millions = source_top5["# Migrants (Millions)"]
# Bars are drawn bottom-up, so the descending series is reversed to put the
# largest value on top.
ax.barh(
    range(len(source_top5)),
    outflow_millions[::-1],
    color="#456a83",
    height=0.7,
)
# Derive label rows from the actual row count so labels stay aligned with
# their bars even when fewer than five countries are available.
top_row = len(source_top5) - 1
for rank, (country_name, millions) in enumerate(
    zip(source_top5["Country"], outflow_millions)
):
    y_pos = top_row - rank
    ax.text(
        0,
        y_pos,
        country_name + " ",
        horizontalalignment="right",
        verticalalignment="center",
        fontsize=15 * 1.4,
    )
    ax.text(
        millions,
        y_pos,
        f"{millions:.2f} ",
        horizontalalignment="right",
        verticalalignment="center",
        fontsize=13 * 1.4,
        color="white",
    )
ax.set_title("Top 5 outflow countries", x=0.25, y=1.02)
fig.savefig(
    "figures/fig2d.pdf",
    facecolor="white",
    transparent=False,
    dpi=300,
)


# Figure 2e: horizontal bar chart of the top origin -> destination pairs
# (pair_top5, sorted descending, index 0..n-1).

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("figures", exist_ok=True)

fig, ax = plt.subplots(figsize=(6, 4))
ax.axis("off")
# Very wide left margin: the "Origin to Destination" labels are long and are
# drawn left of x = 0.
fig.subplots_adjust(left=0.75, bottom=0.001, right=0.999, top=0.86)

pair_millions = pair_top5["# Migrants (Millions)"]
# Bars are drawn bottom-up, so the descending series is reversed to put the
# largest value on top.
ax.barh(
    range(len(pair_top5)),
    pair_millions[::-1],
    color="#456a83",
    height=0.7,
)
# Derive label rows from the actual row count so labels stay aligned with
# their bars even when fewer than five pairs are available.
top_row = len(pair_top5) - 1
for rank, (origin_name, destination_name, millions) in enumerate(
    zip(pair_top5["Origin"], pair_top5["Destination"], pair_millions)
):
    y_pos = top_row - rank
    ax.text(
        0,
        y_pos,
        origin_name + " to " + destination_name + " ",
        horizontalalignment="right",
        verticalalignment="center",
        fontsize=15 * 1.4,
    )
    ax.text(
        millions,
        y_pos,
        f"{millions:.2f} ",
        horizontalalignment="right",
        verticalalignment="center",
        fontsize=12 * 1.4,
        color="white",
    )
# The negative x places the title to the left of the narrow axes area.
ax.set_title("Top 5 country pairs", x=-1, y=1.02)
fig.savefig(
    "figures/fig2e.pdf",
    facecolor="white",
    transparent=False,
    dpi=300,
)
